#include "netlib/net/timer/timer_queue.h"

#include <sys/timerfd.h>
#include <unistd.h>

#include <functional>

#include "netlib/base/logger.h"
#include "netlib/net/event/event_loop.h"
#include "netlib/net/timer/timer.h"

namespace netlib::net::detail {

// Creates a timerfd on CLOCK_MONOTONIC with the non-blocking and
// close-on-exec flags set.
//
// Returns the value produced by timerfd_create(). A negative result is
// reported through LOG_SYSFATAL before it is returned.
int CreateTimerFd() {
	const int fd = ::timerfd_create(CLOCK_MONOTONIC, TFD_NONBLOCK | TFD_CLOEXEC);
	if (fd >= 0) {
		return fd;
	}
	LOG_SYSFATAL << "Failed in timerfd_create";
	return fd;
}

// Converts the absolute time `when` into a delay relative to Timestamp::Now(),
// expressed as a timespec.
//
// Delays below 100 microseconds (including negative ones, i.e. `when` already
// in the past) are raised to 100 microseconds.
struct timespec HowMuchTimeFromNow(Timestamp when) {
	constexpr int64_t kMinDelayMicros = 100;
	const int64_t remaining =
	    when.MicroSecondsSinceEpoch() - Timestamp::Now().MicroSecondsSinceEpoch();
	const int64_t delay = remaining < kMinDelayMicros ? kMinDelayMicros : remaining;

	struct timespec result;
	// Whole seconds go to tv_sec; the sub-second remainder is scaled from
	// microseconds to nanoseconds.
	result.tv_sec = static_cast<time_t>(delay / Timestamp::kMicroSecondsPerSecond);
	result.tv_nsec = static_cast<int64_t>((delay % Timestamp::kMicroSecondsPerSecond) * 1000);
	return result;
}

// Arms `timerfd` so that it becomes readable when `expiration` is reached,
// which wakes up the loop polling it.
//
// Only it_value is filled in (with the delay computed by
// HowMuchTimeFromNow()); it_interval stays zero. A non-zero return from
// timerfd_settime() is logged with LOG_SYSERR and otherwise ignored.
void ResetTimerfd(int timerfd, Timestamp expiration) {
	struct itimerspec next_setting{};
	struct itimerspec previous_setting{};
	next_setting.it_value = HowMuchTimeFromNow(expiration);
	// Flags == 0: it_value is interpreted relative to the current time.
	if (::timerfd_settime(timerfd, 0, &next_setting, &previous_setting)) {
		LOG_SYSERR << "timerfd_settime()";
	}
}

// Drains the pending expiration count from `timerfd` so that it stops being
// readable.
//
// `now` is only used for the trace log line. A read that does not return
// exactly sizeof(uint64_t) bytes is logged with LOG_ERROR and otherwise
// ignored.
void ReadTimerfd(int timerfd, Timestamp now) {
	// Zero-initialized so the trace line below never streams an indeterminate
	// value when read() fails or comes back short.
	uint64_t howmany = 0;
	const ssize_t n = ::read(timerfd, &howmany, sizeof howmany);
	LOG_TRACE << "TimerQueue::handleRead() " << howmany << " at " << now.ToString();
	if (n != static_cast<ssize_t>(sizeof howmany)) {
		LOG_ERROR << "TimerQueue::handleRead() reads " << n << " bytes instead of 8";
	}
}

} // namespace netlib::net::detail

// Out-of-line definition of the static data member Timer::s_num_created.
// NOTE(review): presumably a counter of constructed Timer objects — confirm
// against the declaration in timer.h.
netlib::AtomicInt64 netlib::net::Timer::s_num_created;

// Builds a timer queue bound to `loop`: creates the timerfd, wraps it in
// timer_channel_, and enables read events on that channel so HandleRead()
// is invoked when the descriptor becomes readable.
netlib::net::TimerQueue::TimerQueue(EventLoop* loop) :
    loop_(loop),
    kTimerfd(detail::CreateTimerFd()),
    timer_channel_(loop, kTimerfd) {
	// The Timestamp supplied by the channel is ignored; HandleRead() takes its
	// own reading of the clock.
	auto on_timerfd_readable = [this](Timestamp) { this->HandleRead(); };
	timer_channel_.SetReadCallback(on_timerfd_readable);
	timer_channel_.EnableReading();
}

// Tears the queue down: detaches the channel, closes the timerfd, and frees
// every Timer still stored in timers_.
netlib::net::TimerQueue::~TimerQueue() {
	timer_channel_.DisableAll();
	timer_channel_.Remove();
	::close(kTimerfd);
	// timers_ holds raw Timer pointers, so each remaining one is deleted here.
	for (const auto& pending : timers_) {
		delete pending.second;
	}
}

// Removes from timers_ every entry ordered before the sentry {now, INT64_MAX}
// and returns those entries in their stored order.
//
// The returned entries carry the raw Timer pointers that timers_ held; they
// are not deleted here.
std::vector<netlib::net::TimerQueue::Entry> netlib::net::TimerQueue::GetExpired(Timestamp now) {
	// NOTE(review): INT64_MAX is presumably the largest possible secondary key,
	// so that timers expiring exactly at `now` sort before the sentry — the
	// assert below checks that the first survivor is strictly later than `now`.
	const TimerIndex sentry{now, INT64_MAX};
	const auto first_pending = timers_.lower_bound(sentry);
	assert(first_pending == timers_.end() || now < first_pending->first.time_);

	std::vector<Entry> expired(timers_.begin(), first_pending);
	timers_.erase(timers_.begin(), first_pending);
	return expired;
}

void netlib::net::TimerQueue::Reset(const std::vector<Entry>& expired, Timestamp time) {
	for (const Entry& entry : expired) {
		if (entry.second->Repeat() && cancels_.find(entry.first) == cancels_.end()) {
			entry.second->Restart(time);
			Insert(entry.second);
		} else {
			delete entry.second;
		}
	}

	if (!timers_.empty()) {
		Timestamp next_expired = timers_.begin()->second->Expiration();
		detail::ResetTimerfd(kTimerfd, next_expired);
	}
}

// Stores `timer` in timers_ under TimerIndex(timer). Must run on the loop
// thread (asserted).
//
// Returns true when, before the insertion, timers_ was empty or `timer`
// expires earlier than its first entry — i.e. the earliest expiration changed.
bool netlib::net::TimerQueue::Insert(Timer* timer) {
	loop_->AssertInLoopThread();
	const bool becomes_earliest =
	    timers_.empty() || timer->Expiration() < timers_.begin()->first.time_;
	const bool inserted = timers_.insert({TimerIndex(timer), timer}).second;
	assert(inserted);
	(void)inserted; // keeps NDEBUG builds free of an unused-variable warning
	return becomes_earliest;
}

// Creates a Timer for `cb` that expires at `when` (with `interval` forwarded
// to the Timer constructor) and schedules its insertion on the loop via
// RunInLoop().
//
// Returns the TimerIndex built from the new timer; it is the handle accepted
// by Cancel().
netlib::net::TimerIndex
netlib::net::TimerQueue::AddTimer(const TimerCallback& cb, Timestamp when, double interval) {
	auto* timer = new Timer(cb, when, interval);
	// Build the index BEFORE handing the timer to the loop. Once RunInLoop()
	// has been called from another thread, the loop thread owns the timer and
	// may run and delete it before this function returns, so `timer` must not
	// be touched afterwards.
	// NOTE(review): assumes TimerIndex's constructor reads from *timer — confirm.
	TimerIndex index(timer);
	loop_->RunInLoop([this, timer] { AddTimerInLoop(timer); });
	return index;
}

// Requests cancellation of the timer identified by `idx`. The actual work is
// done by CancelInLoop(), dispatched through the loop's RunInLoop().
void netlib::net::TimerQueue::Cancel(TimerIndex idx) {
	auto cancel_task = [this, idx] { this->CancelInLoop(idx); };
	loop_->RunInLoop(cancel_task);
}

// Loop-thread half of AddTimer(): inserts `timer` and, when Insert() reports
// that it is now the earliest one, re-arms the timerfd for its expiration.
void netlib::net::TimerQueue::AddTimerInLoop(Timer* timer) {
	loop_->AssertInLoopThread();
	if (!Insert(timer)) {
		// Earliest expiration unchanged; the timerfd setting is left as is.
		return;
	}
	detail::ResetTimerfd(kTimerfd, timer->Expiration());
}

void netlib::net::TimerQueue::HandleRead() {
	loop_->AssertInLoopThread();
	Timestamp now(Timestamp::Now());
	detail::ReadTimerfd(kTimerfd, now);
	std::vector<Entry> expired(GetExpired(now));
	is_calling_expired_ = true;
	cancels_.clear();
	for (const Entry& entry : expired) {
		entry.second->Run();
	}
	is_calling_expired_ = false;
	Reset(expired, now);
}

// Loop-thread half of Cancel(). Must run on the loop thread (asserted).
//
// If `idx` is still in timers_, the timer is deleted and removed. Otherwise,
// if expired timers are currently being run (is_calling_expired_), `idx` is
// recorded in cancels_ so that Reset() does not restart that timer.
void netlib::net::TimerQueue::CancelInLoop(TimerIndex idx) {
	loop_->AssertInLoopThread();
	const auto it = timers_.find(idx);
	if (it != timers_.end()) {
		delete it->second;
		timers_.erase(it);
		return;
	}
	if (is_calling_expired_) {
		// `it` equals timers_.end() here and must not be dereferenced, so the
		// entry is keyed by `idx` directly. The mapped pointer is a
		// placeholder: the code in this file only looks cancels_ up by key.
		// NOTE(review): assumes cancels_ is a map-like container keyed by
		// TimerIndex with a Timer* mapped value — confirm in timer_queue.h.
		cancels_.insert({idx, nullptr});
	}
}
